Arquitecturas de Memoria Profunda: Redes Neuronales Recurrentes, LSTM y GRU

Fundamentos Teóricos e Implementación Modular en PyTorch

Autor/a

Ph.D. Pablo Eduardo Caicedo Rodríguez

Fecha de publicación

Invalid Date

1. La Dimensión Temporal en el Procesamiento de Señales

La modelización matemática de sistemas dinámicos requiere representaciones capaces de capturar dependencias temporales. A diferencia de los modelos estáticos, que asumen independencia entre observaciones, las señales biológicas, el lenguaje y las series financieras son intrínsecamente secuenciales. Las Redes Neuronales Recurrentes (RNN) introducen bucles de retroalimentación que permiten mantener un estado interno, funcionando como una memoria persistente a través del tiempo.


2. Redes Neuronales Recurrentes (RNN)

2.1 Arquitectura y Dinámica de Estado

Una RNN procesa secuencias mediante la aplicación recursiva de una función de transición. Conceptualmente, la red se “desenrolla” a lo largo de los pasos de tiempo \(T\), compartiendo los mismos parámetros en cada paso.

La actualización del estado oculto se define formalmente como:

\[h_t = \sigma_h (W_{ih} x_t + b_{ih} + W_{hh} h_{t-1} + b_{hh})\]

Donde:

  • \(h_t \in \mathbb{R}^h\): Vector de estado oculto en el tiempo \(t\).
  • \(x_t \in \mathbb{R}^d\): Vector de entrada en el tiempo \(t\).
  • \(W_{hh} \in \mathbb{R}^{h\times h}\): Matriz de pesos recurrente que conecta el pasado con el presente.
  • \(\sigma_h\): Función de activación no lineal (usualmente \(\tanh\)).

La salida del sistema en cada instante se proyecta desde el estado oculto:

\[y_t = \sigma_y(W_{hy} h_t + b_y)\]

2.2 Limitaciones del Gradiente

Durante el entrenamiento mediante Backpropagation Through Time (BPTT), el cálculo del gradiente implica productos sucesivos de matrices Jacobianas:

\[\frac{\partial h_t}{\partial h_k} = \prod_{j=k+1}^t \text{diag}(\sigma'(z_j)) W_{hh}\]

El comportamiento espectral de \(W_{hh}\) determina la estabilidad del aprendizaje: 1. Desvanecimiento: Si el radio espectral es menor a 1, la señal de error decae exponencialmente, impidiendo capturar dependencias lejanas. 2. Explosión: Si es mayor a 1, los gradientes divergen, desestabilizando los pesos.


3. Arquitectura Long Short-Term Memory (LSTM)

La arquitectura LSTM introduce una celda de memoria (\(C_t\)) independiente del estado oculto (\(h_t\)), diseñada para preservar el flujo del gradiente mediante interacciones lineales.

3.1 Mecanismo de Compuertas

El flujo de información es regulado por estructuras sigmoidales:

  1. Olvido (\(f_t\)): Determina la información a descartar del estado de celda previo. \[f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f)\]

  2. Entrada (\(i_t\)) y Candidato (\(\tilde{C}_t\)): Regulan la incorporación de nueva información. \[i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i)\] \[\tilde{C}_t = \tanh(W_C \cdot [h_{t-1}, x_t] + b_C)\]

  3. Actualización de Celda (\(C_t\)): Combina linealmente el pasado y el presente. \[C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t\]

  4. Salida (\(o_t\)) y Estado Oculto: \[o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o)\] \[h_t = o_t \odot \tanh(C_t)\]


4. Gated Recurrent Units (GRU)

La GRU optimiza la arquitectura recurrente fusionando los estados y reduciendo el sistema a dos compuertas de control, mejorando la eficiencia computacional sin sacrificar significativamente la capacidad de modelado.

4.1 Dinámica de Transición

  • Compuerta de Actualización (\(z_t\)): Controla cuánto del estado anterior se mantiene. \[z_t = \sigma(W_z \cdot [h_{t-1}, x_t] + b_z)\]

  • Compuerta de Reinicio (\(r_t\)): Determina la relevancia del pasado para el cálculo actual. \[r_t = \sigma(W_r \cdot [h_{t-1}, x_t] + b_r)\]

  • Estado Oculto (\(h_t\)): Interpolación directa. \[\tilde{h}_t = \tanh(W_h \cdot [r_t \odot h_{t-1}, x_t] + b_h)\] \[h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t\]

4.2 Selección de Arquitectura

  • LSTM: Preferible para secuencias con dependencias temporales muy extensas o cuando la separación entre memoria y estado oculto es crítica.
  • GRU: Ideal para sistemas con restricciones de cómputo, inferencia en tiempo real y conjuntos de datos limitados, debido a su menor número de parámetros.

5. Implementación Modular en PyTorch

La implementación se estructura mediante clases modulares que encapsulan la lógica recurrente, permitiendo la instanciación flexible de RNN, LSTM o GRU. Se asume una entrada tridimensional estandarizada: (Batch Size, Sequence Length, Features).

5.1 Estructura Genérica de Clasificación

Esta clase implementa un modelo “Muchos a Uno” (Many-to-One), donde la secuencia completa es procesada para generar una única salida basada en el último estado oculto.

import torch
import torch.nn as nn
from typing import Tuple, Union

class ModuloRecurrenteGenerico(nn.Module):
    """
    Bloque constructivo para procesamiento de secuencias temporales.
    Encapsula la lógica de selección de arquitectura recurrente.
    """
    def __init__(self,
                 tipo_arquitectura: str,
                 dim_entrada: int,
                 dim_oculta: int,
                 dim_salida: int):
        """
        Configuración de la arquitectura.

        Args:
            tipo_arquitectura (str): Identificador ('RNN', 'LSTM', 'GRU').
            dim_entrada (int): Número de características por paso de tiempo.
            dim_oculta (int): Tamaño del vector de estado interno.
            dim_salida (int): Dimensión del vector de salida final.
        """
        super(ModuloRecurrenteGenerico, self).__init__()
        self.tipo = tipo_arquitectura
        self.dim_oculta = dim_oculta

        # Inicialización agnóstica de la capa recurrente
        if self.tipo == 'RNN':
            self.rnn = nn.RNN(dim_entrada, dim_oculta, batch_first=True)
        elif self.tipo == 'LSTM':
            self.rnn = nn.LSTM(dim_entrada, dim_oculta, batch_first=True)
        elif self.tipo == 'GRU':
            self.rnn = nn.GRU(dim_entrada, dim_oculta, batch_first=True)
        else:
            raise ValueError("Arquitectura no soportada.")

        # Capa de proyección final
        self.proyeccion = nn.Linear(dim_oculta, dim_salida)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Procesamiento de la secuencia.

        Args:
            x (torch.Tensor): Entrada de forma (Batch, Seq_Len, Features).

        Returns:
            torch.Tensor: Salida proyectada del último estado temporal.
        """
        # Propagación a través de la capa recurrente
        if self.tipo == 'LSTM':
            # LSTM retorna: output, (hidden_state, cell_state)
            salida_secuencia, (h_n, c_n) = self.rnn(x)
        else:
            # GRU/RNN retornan: output, hidden_state
            salida_secuencia, h_n = self.rnn(x)

        # Extracción del estado correspondiente al último paso de tiempo (t=T)
        # salida_secuencia shape: (Batch, Seq_Len, Hidden_Dim)
        estado_final = salida_secuencia[:, -1, :]

        return self.proyeccion(estado_final)

5.2 Estructura Genérica de Generación Secuencial

Esta clase implementa un modelo “Muchos a Muchos” (Many-to-Many) autorregresivo, típico en tareas donde la salida en \(t\) depende de la historia hasta \(t\). Utiliza capas de Embedding para manejar entradas discretas.

class GeneradorSecuencial(nn.Module):
    """
    Modelo autorregresivo para generación de secuencias discretas.
    Soporta mantenimiento de estado entre pasos de inferencia.
    """
    def __init__(self,
                 tamano_vocabulario: int,
                 dim_embedding: int,
                 dim_oculta: int,
                 num_capas: int = 1):
        super(GeneradorSecuencial, self).__init__()
        self.dim_oculta = dim_oculta
        self.num_capas = num_capas

        # Transformación de índices discretos a espacio vectorial denso
        self.embedding = nn.Embedding(tamano_vocabulario, dim_embedding)

        # Núcleo recurrente (LSTM por defecto para memoria de largo plazo)
        self.lstm = nn.LSTM(dim_embedding, dim_oculta, num_capas, batch_first=True)

        # Decodificador al espacio original
        self.decodificador = nn.Linear(dim_oculta, tamano_vocabulario)

    def forward(self, x: torch.Tensor, estado_previo: Tuple[torch.Tensor, torch.Tensor]) \
            -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Paso de inferencia.

        Args:
            x: Índices de entrada (Batch, Seq_Len).
            estado_previo: Tupla con (hidden_state, cell_state) del paso t-1.

        Returns:
            output: Distribución de probabilidad no normalizada (logits).
            nuevo_estado: Estado actualizado para el paso t+1.
        """
        vectores = self.embedding(x)

        # Actualización recurrente
        salida_rnn, nuevo_estado = self.lstm(vectores, estado_previo)

        # Aplanamiento para procesamiento denso
        # (Batch * Seq_Len, Hidden_Dim)
        salida_aplanada = salida_rnn.reshape(-1, self.dim_oculta)

        output = self.decodificador(salida_aplanada)
        return output, nuevo_estado

    def inicializar_estado(self, batch_size: int, dispositivo: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
        """Genera el estado cero inicial."""
        peso_ref = next(self.parameters()).data
        h_0 = peso_ref.new(self.num_capas, batch_size, self.dim_oculta).zero_().to(dispositivo)
        c_0 = peso_ref.new(self.num_capas, batch_size, self.dim_oculta).zero_().to(dispositivo)
        return (h_0, c_0)

6. Perspectiva Comparativa: RNN vs. Transformers

Aunque las arquitecturas basadas en Attention (Transformers) predominan en el modelado de lenguaje a gran escala, las RNN mantienen ventajas estructurales en dominios específicos:

  1. Complejidad de Inferencia: Las RNN operan con complejidad temporal \(O(N)\) y espacial \(O(1)\) respecto a la longitud de la historia, lo cual es crítico para sistemas embebidos (Edge AI) y procesamiento de señales en tiempo real (Streaming).
  2. Eficiencia de Datos: En regímenes de Small Data o señales con alta relación señal-ruido, arquitecturas como GRU suelen converger con mayor estabilidad que modelos masivos.

La elección de la arquitectura debe basarse en las restricciones de latencia, memoria disponible y la naturaleza causal de los datos.